{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "RL_ActorCritic_DDPG_Movie_Recommendation.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/bcsrn/RL_DDPG_Recommendation/blob/master/RL_ActorCritic_DDPG_Movie_Recommendation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DsEJ5r1AhOLK",
        "colab_type": "text"
      },
      "source": [
        "# Data Loading and Environment Setup\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kk41_9F5Dwoi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Mounting the drive: the MovieLens files below are read from /content/gdrive\n",
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r1NmejTIFPtf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Dependencies\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "from scipy.sparse.linalg import svds\n",
        "from sklearn.model_selection import train_test_split\n",
        "import tensorflow as tf\n",
        "from torch.utils.data import DataLoader\n",
        "import itertools\n",
        "import torch\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b8aPaxXpNsz7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Loading datasets (MovieLens 1M, '::'-separated .dat files)\n",
        "DATA_DIR = '/content/gdrive/My Drive/RLProject/Data/ml-1m'\n",
        "\n",
        "def read_dat(name, encoding=None):\n",
        "  \"\"\"Read DATA_DIR/name into a list of '::'-split field lists; the file is closed afterwards.\"\"\"\n",
        "  with open(f'{DATA_DIR}/{name}', 'r', encoding=encoding) as f:\n",
        "    return [line.strip().split(\"::\") for line in f]\n",
        "\n",
        "ratings_list = read_dat('ratings.dat')\n",
        "users_list = read_dat('users.dat')\n",
        "movies_list = read_dat('movies.dat', encoding='latin-1')\n",
        "# IDs and ratings are deliberately kept as strings: later cells key their lookups by id string\n",
        "# (e.g. movie_embeddings_dict[str(int(...))]). A dtype=int here may or may not be honored\n",
        "# depending on the pandas version, so it is not passed at all.\n",
        "ratings_df = pd.DataFrame(ratings_list, columns = ['UserID', 'MovieID', 'Rating', 'Timestamp'])\n",
        "# Timestamps must be numeric: as strings, 10-digit values (after Sep 2001) sort before 9-digit ones\n",
        "ratings_df['Timestamp'] = ratings_df['Timestamp'].astype(int)\n",
        "movies_df = pd.DataFrame(movies_list, columns = ['MovieID', 'Title', 'Genres'])\n",
        "movies_df['MovieID'] = movies_df['MovieID'].apply(pd.to_numeric)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bSqs6TACUjCJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#user x movie rating matrix; movies a user never rated are 0\n",
        "R_df = (ratings_df\n",
        "        .pivot(index = 'UserID', columns = 'MovieID', values = 'Rating')\n",
        "        .fillna(0)\n",
        "        .astype(int))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iFBmC7b9Wls-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#might be used in the user-dependent state representation\n",
        "userids = list(R_df.index.values) #list of userids\n",
        "idx_to_userids = dict(enumerate(userids))\n",
        "userids_to_idx = {user_id: i for i, user_id in enumerate(userids)}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m2b0pDxe-Kng",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#list of movie ids, in R_df column order (also the row order of the SVD V matrix below)\n",
        "columns = list(R_df)\n",
        "idx_to_id = dict(enumerate(columns))\n",
        "id_to_idx = {movie_id: i for i, movie_id in enumerate(columns)}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FISA0DvT8fn5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#to get item embeddings\n",
        "#R_df.loc[userid, movieid] -> rating (R_df rows are UserIDs, columns are MovieIDs)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_RQ7Rjif5-dh",
        "colab_type": "text"
      },
      "source": [
        "## Getting User and Item (Movie ID) Embeddings\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9CmSP_wmg5-U",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#de-mean each user's row; the mean includes the zero-filled unrated movies\n",
        "R = R_df.values\n",
        "user_ratings_mean = R.mean(axis = 1)\n",
        "R_demeaned = R - user_ratings_mean[:, np.newaxis]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sTMwltBwigpE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Movie Embeddings: truncated SVD with k=100; row i of V embeds movie columns[i]\n",
        "U, sigma, Vt = svds(R_demeaned, k = 100)\n",
        "V = Vt.T\n",
        "movie_list = V.tolist()\n",
        "movie_embeddings_dict = {movie_id: tf.convert_to_tensor(vector)\n",
        "                         for movie_id, vector in zip(columns, movie_list)}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WAjBAVKZgXc5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#User Embeddings: row i of U embeds userids[i]\n",
        "user_list = U.tolist()\n",
        "user_embeddings_dict = {user_id: tf.convert_to_tensor(vector)\n",
        "                        for user_id, vector in zip(userids, user_list)}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qxFmvK6QBBGi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#prepare_dataset: each user's interactions in Timestamp order, keyed by UserID\n",
        "users_df = (ratings_df\n",
        "            .sort_values([\"UserID\", \"Timestamp\"])\n",
        "            .set_index(\"UserID\")\n",
        "            .fillna(0)\n",
        "            .drop(columns=\"Timestamp\"))\n",
        "users = {user_id: group for user_id, group in users_df.groupby(\"UserID\")}\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bcI7Tn4h8cok",
        "colab_type": "text"
      },
      "source": [
        "## Train and Test Datasets"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ayry5fcBNOhK",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Keep users with at least 10 positive (rating 4 or 5) interactions\n",
        "# 10 == 'N' positively interacted items\n",
        "from collections import defaultdict\n",
        "from collections import Counter\n",
        "users_dict = defaultdict(dict)\n",
        "# a list (not a set) keeps the groupby order, so users_id_list -- and the train/test split\n",
        "# drawn from it -- does not depend on Python's per-process string hash randomization\n",
        "users_id_list = []\n",
        "for user_id, user_history in users.items():\n",
        "  # count ratings as ints so the filter works whether Rating holds strings or integers\n",
        "  rating_freq = Counter(int(r) for r in user_history[\"Rating\"].values)\n",
        "  if rating_freq[4] + rating_freq[5] < 10:\n",
        "    continue\n",
        "  users_id_list.append(user_id)\n",
        "  users_dict[user_id][\"item\"] = user_history[\"MovieID\"].values\n",
        "  users_dict[user_id][\"rating\"] = user_history[\"Rating\"].values\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x5Qu67dNVItq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#array form: train_test_split splits it and UserDataset indexes into it\n",
        "users_id_list = np.array(list(users_id_list))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LmQNj6Vu10EU",
        "colab_type": "code",
        "outputId": "27db77f6-1f9e-4875-a566-60d13841c7f3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "#train_test_split with a 25% test share, seeded so the split is reproducible\n",
        "SPLIT_SEED = 42\n",
        "train_users,test_users = train_test_split(users_id_list, test_size=0.25, random_state=SPLIT_SEED)\n",
        "print(train_users[:2])"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['1547' '2723']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x1Ldqt8vhOPI",
        "colab_type": "code",
        "outputId": "c9d5bb5e-2830-4bca-e017-5347b4f08425",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "#peek at the first two held-out user ids\n",
        "print(test_users[:2])"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['2122' '3364']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P1yOD6JA32u8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from torch.utils.data import Dataset\n",
        "class UserDataset(Dataset):\n",
        "  \"\"\"One sample per user: the first 10 positively rated (> 3) items of their history.\n",
        "\n",
        "  users_list: user ids served by position; users_dict: user id -> {\"item\": ids, \"rating\": ratings}.\n",
        "  \"\"\"\n",
        "  def __init__(self,users_list,users_dict):\n",
        "    self.users_list = users_list\n",
        "    self.users_dict = users_dict\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.users_list)\n",
        "\n",
        "  def __getitem__(self,idx):\n",
        "    user_id = self.users_list[idx]\n",
        "    history = self.users_dict[user_id]\n",
        "    # placeholders; they survive only if the user has fewer than 10 positive ratings\n",
        "    items = [('1',)]*10\n",
        "    ratings = [('0',)]*10\n",
        "    filled = 0\n",
        "    for movie, score in zip(history[\"item\"], history[\"rating\"]):\n",
        "      if int(score) > 3 and filled < 10:\n",
        "        items[filled], ratings[filled] = movie, score\n",
        "        filled += 1\n",
        "    return {'item':items,'rating':ratings,'size':len(items),'userid':user_id,'idx':idx}\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Wfp3TcK573H0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#one UserDataset per split; both read interaction histories from users_dict\n",
        "train_users_dataset = UserDataset(train_users,users_dict)\n",
        "test_users_dataset = UserDataset(test_users,users_dict)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4X4OvMdoElSw",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from torch.utils.data import DataLoader\n",
        "\n",
        "#batch_size=1: the training loop reads element [0] of each batched field\n",
        "train_dataloader, test_dataloader = (DataLoader(dataset, batch_size=1)\n",
        "                                     for dataset in (train_users_dataset, test_users_dataset))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "njrztpHpPMYQ",
        "colab_type": "code",
        "outputId": "329604b7-861b-4509-d001-a1d8e21f2b02",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "#number of training users = number of DataLoader batches (batch_size=1)\n",
        "train_num = len(train_dataloader)\n",
        "print(train_num)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "4462\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "so-BkRxW8SPE",
        "colab_type": "text"
      },
      "source": [
        "# State Representation Models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aULb7Dxhua6S",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def drrave_state_rep(userid_b,items,memory,idx):\n",
        "  \"\"\"DRR-ave style state of shape (2N+1, 100): [e_u; e_u * e_1..e_N; e_1..e_N].\n",
        "\n",
        "  userid_b: DataLoader batch of user ids (element 0 is used).\n",
        "  items:    the user's N initial positive items; item[0] is the movie id string.\n",
        "  memory:   per-user array of recent positive movie ids stored as floats (-1 = empty slot).\n",
        "  idx:      this user's row in memory.\n",
        "  Item rows are taken from memory[idx], falling back to items for empty slots, so the state\n",
        "  reflects update_memory calls between steps.\n",
        "  \"\"\"\n",
        "  memory_row = memory[int(idx)]\n",
        "  item_vectors = []\n",
        "  for slot, item in enumerate(items):\n",
        "    remembered = memory_row[slot]\n",
        "    # memory stores ids as floats; embedding keys are the original id strings\n",
        "    movie_id = str(int(remembered)) if remembered >= 0 else item[0]\n",
        "    item_vectors.append(np.array(movie_embeddings_dict[movie_id]))\n",
        "  user_embeddings = torch.Tensor(np.array(user_embeddings_dict[userid_b[0]])).unsqueeze(0)  # (1, 100)\n",
        "  item_embeddings = torch.Tensor(np.array(item_vectors))  # (N, 100)\n",
        "  state = torch.cat([user_embeddings, user_embeddings*item_embeddings, item_embeddings])\n",
        "  return state #state tensor shape [21,100]\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lCkdGnYdfon8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def drru_state_rep(userid_b,items,memory,idx):\n",
        "  \"\"\"DRR-u style state: user-item products, then pairwise products of remembered items.\n",
        "\n",
        "  Rows are e_u * e_i for each item in items (item[0] is the id string), followed by\n",
        "  e_a * e_b for every pair (a, b) of float movie ids in memory[idx], in\n",
        "  itertools.combinations order.\n",
        "  \"\"\"\n",
        "  def movie_vector(movie_id):\n",
        "    # memory stores ids as floats; embedding keys are the original id strings\n",
        "    return np.array(movie_embeddings_dict[str(int(movie_id))])\n",
        "\n",
        "  user_vector = np.array(user_embeddings_dict[userid_b[0]])\n",
        "  rows = [user_vector * np.array(movie_embeddings_dict[item[0]]) for item in items]\n",
        "  rows += [movie_vector(a) * movie_vector(b) for a, b in itertools.combinations(memory[idx], 2)]\n",
        "  return torch.Tensor(rows,) #state tensor shape [55,100]\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xSS7Jg83YWDS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def drrp_state_rep(items,memory,idx):\n",
        "  \"\"\"DRR-p style state: item embeddings, then pairwise products of remembered items.\n",
        "\n",
        "  Rows are e_i for each item in items (item[0] is the id string), followed by e_a * e_b\n",
        "  for every pair (a, b) of float movie ids in memory[idx], in itertools.combinations order.\n",
        "  \"\"\"\n",
        "  def movie_vector(movie_id):\n",
        "    # memory stores ids as floats; embedding keys are the original id strings\n",
        "    return np.array(movie_embeddings_dict[str(int(movie_id))])\n",
        "\n",
        "  rows = [np.array(movie_embeddings_dict[item[0]]) for item in items]\n",
        "  rows += [movie_vector(a) * movie_vector(b) for a, b in itertools.combinations(memory[idx], 2)]\n",
        "  return torch.Tensor(rows,)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bKvM1H0FO1zg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# just n items and their embeddings used to represent state\n",
        "def state_rep(item_b):\n",
        "  \"\"\"Flattened (len(columns) * 100,) state: a movie's row holds its embedding if the\n",
        "  movie is in item_b (item[0] is the id string) and zeros otherwise.\"\"\"\n",
        "  chosen = [item[0] for item in item_b]\n",
        "  table = np.zeros((len(columns),100))\n",
        "  for row in range(len(columns)):\n",
        "    movie_id = idx_to_id[row]\n",
        "    if movie_id in chosen:\n",
        "      table[row] = np.array(movie_embeddings_dict[movie_id])\n",
        "  return torch.reshape(torch.Tensor(table,),[-1])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BnGewmmshbCu",
        "colab_type": "text"
      },
      "source": [
        "# Actor and Critic Modules"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "k4kiw_L7sypy",
        "colab_type": "code",
        "outputId": "4753b580-14d9-45e5-e931-cdbf4392618e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 331
        }
      },
      "source": [
        "!pip install git+https://github.com/pabloppp/pytorch-tools@0.2.4 -U"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting git+https://github.com/pabloppp/pytorch-tools@0.2.4\n",
            "  Cloning https://github.com/pabloppp/pytorch-tools (to revision 0.2.4) to /tmp/pip-req-build-ryzkfew0\n",
            "  Running command git clone -q https://github.com/pabloppp/pytorch-tools /tmp/pip-req-build-ryzkfew0\n",
            "  Running command git checkout -q 86c73996537002ab29e7e40f925cb90756f58156\n",
            "Requirement already satisfied, skipping upgrade: torch==1.* in /usr/local/lib/python3.6/dist-packages (from torchtools==0.2.4) (1.4.0)\n",
            "Requirement already satisfied, skipping upgrade: torchvision in /usr/local/lib/python3.6/dist-packages (from torchtools==0.2.4) (0.5.0)\n",
            "Requirement already satisfied, skipping upgrade: numpy==1.* in /usr/local/lib/python3.6/dist-packages (from torchtools==0.2.4) (1.18.3)\n",
            "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from torchvision->torchtools==0.2.4) (1.12.0)\n",
            "Requirement already satisfied, skipping upgrade: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision->torchtools==0.2.4) (7.0.0)\n",
            "Building wheels for collected packages: torchtools\n",
            "  Building wheel for torchtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for torchtools: filename=torchtools-0.2.4-cp36-none-any.whl size=19201 sha256=b01aa5dbbdc0fd333823de535e912a8abdc3f2d9f313c49f2960f9d2b140abb9\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-tsyf4x7c/wheels/98/e6/12/fbea7d7f60c85f3eeca5f5253cf77eab1a4fd352ea495ff3b9\n",
            "Successfully built torchtools\n",
            "Installing collected packages: torchtools\n",
            "Successfully installed torchtools-0.2.4\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1nZWN38whlMJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Dependencies\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torchtools.optim import Ranger\n",
        "import tqdm\n",
        "import random\n",
        "import matplotlib.pyplot as plt\n",
        "#uncomment to use adam\n",
        "# from torch.optim import Adam\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r3YYQPNuheGY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Actor Model:\n",
        "#Generating an action a based on state s\n",
        "\n",
        "class Actor(torch.nn.Module):\n",
        "  \"\"\"Policy network: flattened state (input_dim) -> action embedding (output_dim).\n",
        "\n",
        "  Two ReLU hidden layers of width hidden_dim, each followed by dropout (p=0.5).\n",
        "  \"\"\"\n",
        "  def __init__(self, input_dim, output_dim,hidden_dim):\n",
        "    super(Actor, self).__init__()\n",
        "    self.drop_layer = nn.Dropout(p=0.5)\n",
        "    self.linear1 = nn.Linear(input_dim, hidden_dim)\n",
        "    self.linear2 = nn.Linear(hidden_dim, hidden_dim)\n",
        "    self.linear3 = nn.Linear(hidden_dim, output_dim)\n",
        "\n",
        "  def forward(self, state):\n",
        "    hidden = state\n",
        "    for layer in (self.linear1, self.linear2):\n",
        "      hidden = self.drop_layer(F.relu(layer(hidden)))\n",
        "    # linear head without tanh: embeddings are standard scaled / whitened, not squashed to [-1, 1]\n",
        "    return self.linear3(hidden)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_6Ss5R4XoKii",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class Critic(nn.Module):\n",
        "  \"\"\"Q-network: maps a (state, action) pair to a single scalar value.\n",
        "\n",
        "  input_dim is the flattened state size and output_dim the action size; the two are\n",
        "  concatenated along dim 1, so both inputs must be 2-D (batch, features).\n",
        "  \"\"\"\n",
        "  def __init__(self,input_dim,output_dim,hidden_dim):\n",
        "    super(Critic, self).__init__()\n",
        "    self.drop_layer = nn.Dropout(p=0.5)\n",
        "    self.linear1 = nn.Linear(input_dim + output_dim, hidden_dim)\n",
        "    self.linear2 = nn.Linear(hidden_dim, hidden_dim)\n",
        "    self.linear3 = nn.Linear(hidden_dim, 1)\n",
        "\n",
        "  def forward(self,state,action):\n",
        "    hidden = torch.cat([state, action], 1)\n",
        "    for layer in (self.linear1, self.linear2):\n",
        "      hidden = self.drop_layer(F.relu(layer(hidden)))\n",
        "    return self.linear3(hidden)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nE3PBjinMmKZ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class ReplayBuffer:\n",
        "    \"\"\"Fixed-capacity ring buffer of (state, action, reward, next_state) transitions.\"\"\"\n",
        "    def __init__(self, capacity):\n",
        "        self.capacity = capacity\n",
        "        self.buffer = []\n",
        "        self.position = 0  # slot the next push writes to\n",
        "    \n",
        "    def push(self, state, action, reward, next_state):\n",
        "        transition = (state, action, reward, next_state)\n",
        "        if len(self.buffer) < self.capacity:\n",
        "            # still filling up: position always equals len(buffer) here\n",
        "            self.buffer.append(transition)\n",
        "        else:\n",
        "            # full: overwrite the oldest transition\n",
        "            self.buffer[self.position] = transition\n",
        "        self.position = (self.position + 1) % self.capacity\n",
        "    \n",
        "    def sample(self, batch_size):\n",
        "        \"\"\"Draw batch_size distinct transitions and np.stack each field across the batch.\"\"\"\n",
        "        batch = random.sample(self.buffer, batch_size)\n",
        "        states, actions, rewards, next_states = (np.stack(field) for field in zip(*batch))\n",
        "        return states, actions, rewards, next_states\n",
        "    \n",
        "    def __len__(self):\n",
        "        return len(self.buffer)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hWWYGkm2JMhB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#everything runs on CPU; ddpg_update moves each sampled batch to this device\n",
        "device = 'cpu'\n",
        "# cuda = torch.device('cuda')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ft471pD0KoKT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#used for plotting purposes: ddpg_update appends the policy and value loss of each update\n",
        "p_loss = []\n",
        "v_loss = []"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zMNI4edhzv_6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def ddpg_update(batch_size=1, \n",
        "                gamma = 0.6,\n",
        "                min_value=-np.inf,\n",
        "                max_value=np.inf,\n",
        "                soft_tau=1e-2):\n",
        "    \"\"\"Run one DDPG update on a minibatch sampled from the global replay_buffer.\n",
        "\n",
        "    The actor (policy_net) is trained to maximise value_net(s, policy_net(s)); the critic\n",
        "    (value_net) regresses onto clamp(r + gamma * Q_target(s', pi_target(s')), min_value, max_value).\n",
        "    Both target networks then move towards their online networks by soft_tau.\n",
        "    The scalar losses are appended to the global p_loss / v_loss lists.\n",
        "    \"\"\"\n",
        "    state, action, reward, next_state = replay_buffer.sample(batch_size)\n",
        "    state      = torch.FloatTensor(state).to(device)\n",
        "    next_state = torch.FloatTensor(next_state).to(device)\n",
        "    action     = torch.FloatTensor(action).to(device)\n",
        "    reward     = torch.FloatTensor(reward).to(device)\n",
        "\n",
        "    policy_loss = -value_net(state, policy_net(state)).mean()\n",
        "\n",
        "    next_action    = target_policy_net(next_state)\n",
        "    target_value   = target_value_net(next_state, next_action.detach())\n",
        "    expected_value = torch.clamp(reward + gamma * target_value, min_value, max_value)\n",
        "    value = value_net(state, action)\n",
        "    value_loss = value_criterion(value, expected_value.detach())\n",
        "\n",
        "    # store plain floats; the loss tensors carry autograd history that would otherwise be\n",
        "    # retained for every update (and matplotlib cannot plot tensors that require grad)\n",
        "    p_loss.append(policy_loss.item())\n",
        "    v_loss.append(value_loss.item())\n",
        "\n",
        "    policy_optimizer.zero_grad()\n",
        "    policy_loss.backward()\n",
        "    policy_optimizer.step()\n",
        "\n",
        "    # policy_loss.backward() also filled value_net's grads; zero_grad discards them\n",
        "    value_optimizer.zero_grad()\n",
        "    value_loss.backward()\n",
        "    value_optimizer.step()\n",
        "\n",
        "    _soft_update(target_value_net, value_net, soft_tau)\n",
        "    _soft_update(target_policy_net, policy_net, soft_tau)\n",
        "\n",
        "\n",
        "def _soft_update(target_net, source_net, tau):\n",
        "    \"\"\"Polyak averaging, in place: target <- (1 - tau) * target + tau * source.\"\"\"\n",
        "    for target_param, param in zip(target_net.parameters(), source_net.parameters()):\n",
        "        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RNlkmYxAotps",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#initializing actor and critic networks for drru and drrp state representation\n",
        "#5500 = 55 state rows x 100 dims (drru/drrp state flattened)\n",
        "#NOTE: the drrave cell below rebuilds every network and optimizer with 2100 inputs, so when the\n",
        "#notebook runs top-to-bottom the objects created here are replaced before training.\n",
        "\n",
        "value_net = Critic(5500,100,256)\n",
        "policy_net = Actor(5500,100,256)\n",
        "\n",
        "target_value_net = Critic(5500,100,256)\n",
        "target_policy_net = Actor(5500,100,256)\n",
        "\n",
        "\n",
        "#target networks are only used for inference: eval() switches their dropout off\n",
        "target_policy_net.eval()\n",
        "target_value_net.eval()\n",
        "\n",
        "#start the target networks as exact copies of the online networks\n",
        "for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):\n",
        "  target_param.data.copy_(param.data)\n",
        "\n",
        "for target_param, param in zip(target_policy_net.parameters(), policy_net.parameters()):\n",
        "  target_param.data.copy_(param.data)\n",
        "\n",
        "value_criterion = nn.MSELoss()\n",
        "value_optimizer      = Ranger(value_net.parameters(),  lr=1e-4)\n",
        "policy_optimizer     = Ranger(policy_net.parameters(), lr=1e-4)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HjhPAD9r75Q0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#initializing for drrave state representation\n",
        "#2100 = 21 state rows x 100 dims (drrave state flattened); actions are 100-d embeddings\n",
        "STATE_DIM, ACTION_DIM, HIDDEN_DIM = 2100, 100, 256\n",
        "\n",
        "value_net = Critic(STATE_DIM, ACTION_DIM, HIDDEN_DIM)\n",
        "policy_net = Actor(STATE_DIM, ACTION_DIM, HIDDEN_DIM)\n",
        "target_value_net = Critic(STATE_DIM, ACTION_DIM, HIDDEN_DIM)\n",
        "target_policy_net = Actor(STATE_DIM, ACTION_DIM, HIDDEN_DIM)\n",
        "\n",
        "#target networks start as exact copies and run in eval mode (dropout off)\n",
        "target_value_net.load_state_dict(value_net.state_dict())\n",
        "target_policy_net.load_state_dict(policy_net.state_dict())\n",
        "target_policy_net.eval()\n",
        "target_value_net.eval()\n",
        "\n",
        "value_criterion = nn.MSELoss()\n",
        "value_optimizer      = Ranger(value_net.parameters(),  lr=1e-4)\n",
        "policy_optimizer     = Ranger(policy_net.parameters(), lr=1e-4)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oq0XEAKo-Yhi",
        "colab_type": "code",
        "outputId": "623ca5dc-48f3-42ac-87c5-238093d93161",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 127
        }
      },
      "source": [
        "#NOTE: the saved output shows in_features=5500, i.e. it was produced after running the drru/drrp\n",
        "#initialization cell last; run top-to-bottom, this prints the 2100-input drrave actor\n",
        "print(policy_net)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Actor(\n",
            "  (drop_layer): Dropout(p=0.5, inplace=False)\n",
            "  (linear1): Linear(in_features=5500, out_features=256, bias=True)\n",
            "  (linear2): Linear(in_features=256, out_features=256, bias=True)\n",
            "  (linear3): Linear(in_features=256, out_features=100, bias=True)\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WZmfxcz9Mmmy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "replay_buffer_size = 10000\n",
        "replay_buffer = ReplayBuffer(replay_buffer_size)\n",
        "\n",
        "#per training user: the 10 most recent positive movie ids, -1.0 marking an empty slot\n",
        "memory = np.full((train_num, 10), -1.0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y2IHzJraDgHn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_action(state,action_emb,userid_b,item_b,preds):\n",
        "  \"\"\"Pick the best-scoring movie from the user's own rated history that is not yet in preds.\n",
        "\n",
        "  Every movie in users_dict[user][\"item\"] is scored by its dot product with action_emb; the\n",
        "  winner is added to preds (mutated in place) and its index into that array is returned.\n",
        "  Returns None when every candidate is already in preds. state and item_b are unused.\n",
        "  \"\"\"\n",
        "  query = torch.reshape(action_emb,[1,100]).unsqueeze(0)  # (1, 1, 100)\n",
        "  candidates = users_dict[userid_b[0]][\"item\"]\n",
        "  candidate_matrix = torch.Tensor([np.array(movie_embeddings_dict[movie]) for movie in candidates],)\n",
        "  scores = torch.bmm(query, candidate_matrix.T.unsqueeze(0)).squeeze(0)  # (1, num_candidates)\n",
        "  _, ranking = torch.sort(scores,descending=True)\n",
        "  for position in list(ranking[0]):\n",
        "    movie = candidates[position]\n",
        "    if movie not in preds:\n",
        "      preds.add(movie)\n",
        "      return int(position)\n",
        "  return None\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aMOxsJ4Yiyj3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def update_memory(memory,action,idx):\n",
        "  \"\"\"Drop the oldest entry of memory[idx] and append action as the newest one (in place).\"\"\"\n",
        "  memory[idx] = [*memory[idx,1:], action]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VkpK8GRf70ui",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#last raw rating seen; the training loop reassigns it at every step\n",
        "rate = 0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kztx5pqv7r-p",
        "colab_type": "text"
      },
      "source": [
        "# Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ptAHnMaMz5fR",
        "colab_type": "code",
        "outputId": "b2509e9e-8dba-4b7a-fbf9-10dce9e2ee9b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "#one episode per training user: seed the user's memory, then recommend EPISODE_STEPS items\n",
        "EPISODE_STEPS = 5\n",
        "batch_size = 1\n",
        "preddict = dict()\n",
        "# iterate the DataLoader directly so every training user (including the last) is visited\n",
        "for first in tqdm.tqdm(train_dataloader):\n",
        "  preds = set()\n",
        "  item_b,rating_b,size_b,userid_b,idx_b = first['item'],first['rating'],first['size'],first['userid'],first['idx']\n",
        "  user_items = users_dict[userid_b[0]][\"item\"]\n",
        "  user_ratings = users_dict[userid_b[0]][\"rating\"]\n",
        "  memory[idx_b] = [item[0] for item in item_b]\n",
        "  state = drrave_state_rep(userid_b,item_b,memory,idx_b)\n",
        "  for step in range(EPISODE_STEPS):\n",
        "    # named state_flat so the loop does not rebind (and destroy) the global state_rep() function\n",
        "    state_flat = torch.reshape(state,[-1])\n",
        "    action_emb = policy_net(state_flat)\n",
        "    action = get_action(state,action_emb,userid_b,item_b,preds)\n",
        "    if action is None:\n",
        "      # every candidate item has already been recommended to this user\n",
        "      break\n",
        "    rate = int(user_ratings[action])\n",
        "    # rating 1..5 -> reward -1..1\n",
        "    reward = torch.Tensor(((rate-3)/2,))\n",
        "\n",
        "    if reward > 0:\n",
        "      update_memory(memory,int(user_items[action]),idx_b)\n",
        "\n",
        "    next_state = drrave_state_rep(userid_b,item_b,memory,idx_b)\n",
        "    next_state_flat = torch.reshape(next_state,[-1])\n",
        "    replay_buffer.push(state_flat.detach().cpu().numpy(), action_emb.detach().cpu().numpy(), reward, next_state_flat.detach().cpu().numpy())\n",
        "    if len(replay_buffer) > batch_size:\n",
        "      ddpg_update(batch_size=batch_size)\n",
        "\n",
        "    state = next_state\n",
        "  preddict[userid_b[0]] = preds\n"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 4461/4461 [07:43<00:00,  9.63it/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-KxvxcIJOTAM",
        "colab_type": "code",
        "outputId": "3a8386cb-8954-4f13-b7e8-59c53c7ab239",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 285
        }
      },
      "source": [
        "# NOTE(review): assumes ddpg_update appends one value-network loss per call to v_loss; confirm.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(v_loss)\n",
        "ax.set(title=\"Value network loss (v_loss)\", xlabel=\"logged step\", ylabel=\"loss\");"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[<matplotlib.lines.Line2D at 0x7f0bcb633c18>]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 123
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD6CAYAAAC/KwBlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dfZxV1X3v8c8XEK4FiSgjoTwUVJJeNQ3q3Mi9bfOysVGkN8W0aYptA02tNDd6m9y0t8F6b7VJTE1aY2uamGKlQm4qIZoUXhGDlBitSRAHRRCfGJ6ECc8gDyIPM/O7f5w1uGc8Z88wc2bOzJnv+/U6r7PPb6+199p7zsxv9l5r762IwMzMrJQBlW6AmZn1bk4UZmaWy4nCzMxyOVGYmVkuJwozM8vlRGFmZrnaTRSSxkl6XNKLktZL+lSKnyNpuaQN6X1EikvSPZLqJa2VdFlmWbNS+Q2SZmXil0tal+rcI0l56zAzs56j9q6jkDQaGB0Rz0o6C1gNXAf8IbA/Iu6UNAcYERGflTQN+J/ANOAK4B8i4gpJ5wB1QC0QaTmXR8QBSauAPwWeBpYC90TEo5K+XGwdee0dOXJkTJgwoXN7w8ysn1q9evXeiKgpNm9Qe5UjYgewI00flvQSMAaYDlyZis0HfgR8NsUXRCEDrZR0dko2VwLLI2I/gKTlwFRJPwKGR8TKFF9AIRE9mrOOkiZMmEBdXV17m2VmZhmStpaad1p9FJImAJdS+M9/VEoiADuBUWl6DLAtU217iuXFtxeJk7MOMzPrIR1OFJKGAQ8Dn46IQ9l56eihW+8FkrcOSbMl1Umq27NnT3c2w8ys3+lQopB0BoUk8a2I+G4K70qnlFr6MXaneAMwLlN9bIrlxccWieeto5WImBsRtRFRW1NT9BSbmZl1UkdGPQm4H3gpIr6SmbUEaBm5NAtYnInPTKOfpgAH0+mjZcDVkkak0UtXA8vSvEOSpqR1zWyzrGLrMDOzHtJuZzbwy8DHgHWS1qTYXwJ3Aosk3QBsBT6a5i2lMOKpHjgKfBwgIvZL+jzwTCr3uZaObeCTwAPAmRQ6sR9N8VLrMDOzHtLu8Ni+pra2Njzqyczs9EhaHRG1xeb5ymwzM8vlRFFlmpuDRXXbONnUXOmmmFmVcKKoMg89u52/eGgtc5/cVOmmmFmVcKKoMgePngTgwBsnKtwSM6sWThRmZpbLicLMzHI5UZiZWS4nCjMzy+VEYWZmuZwoqlR1XW9vZpXkRFFlCg+RNTMrHycKMzPL5URhZma5nCjMzCyXE4WZmeVyojAzs1xOFFWqyp5HZWYV1JFnZs+TtFvSC5nYtyWtSa8tLY9IlTRB0puZed/I1Llc0jpJ9ZLuSc/HRtI5kpZL2pDeR6S4Url6SWslXVb+zTczs/Z05IjiAWBqNhARvxsRkyNiMvAw8N3M7I0t8yLiE5n4vcCNwKT0alnmHGBFREwCVqTPANdmys5O9c3MrIe1mygi4klgf7F56ajgo8CDecuQNBoYHhEro/CQ7gXAdWn2dGB+mp7fJr4gClYCZ6flmJlZD+pqH8WvArsiYkMmNlHSc5KekPSrKTYG2J4psz3FAEZFxI40vRMYlamzrUSdViTNllQnqW7Pnj1d2BwzM2urq4nielofTewAxkfEpcBngH+VNLyjC0tHG6fdDRsRcyOiNiJqa2pqTrd6VfKtPMysXAZ1tqKkQcBvAZe3xCLiOHA8Ta+WtBF4F9AAjM1UH5tiALskjY6IHenU0u4UbwDGlahjZmY9pCtHFL8OvBwRp04pSaqRNDBNn0+hI3pTOrV0SNKU1K8xE1icqi0BZqXpWW3iM9PopynAwcwpKmuHh8eaWbl0ZHjsg8BPgXdL2i7phjRrBm/vxH4/sDYNl30I+EREtHSEfxL4Z6Ae2Ag8muJ3Ah+UtIFC8rkzxZcCm1L5+1J9MzPrYe2eeo
qI60vE/7BI7GEKw2WLla8DLikS3wdcVSQewE3ttc/MzLqXr8w2M7NcThRmZpbLicLMzHI5UZiZWS4niioVp3/doplZUU4UVUa+JNvMysyJwszMcjlRmJlZLicKMzPL5URhZma5nCjMzCyXE0WV8t1jzaxcnCjMzCyXE0WV8uUUZlYuThRmZpbLicLMzHJ15Al38yTtlvRCJna7pAZJa9JrWmbeLZLqJb0i6ZpMfGqK1Uuak4lPlPR0in9b0uAUH5I+16f5E8q10WZm1nEdOaJ4AJhaJH53RExOr6UAki6i8IjUi1Odr0samJ6j/TXgWuAi4PpUFuBLaVkXAgeAlket3gAcSPG7UzkzM+th7SaKiHgS2N9euWQ6sDAijkfEZgrPu35fetVHxKaIOAEsBKarcAe7D1B4vjbAfOC6zLLmp+mHgKvkO951mIfHmlm5dKWP4mZJa9OpqREpNgbYlimzPcVKxc8FXo+IxjbxVstK8w+m8pbDmdTMyq2zieJe4AJgMrADuKtsLeoESbMl1Umq27NnTyWbYmZWdTqVKCJiV0Q0RUQzcB+FU0sADcC4TNGxKVYqvg84W9KgNvFWy0rz35HKF2vP3IiojYjampqazmySmZmV0KlEIWl05uOHgZYRUUuAGWnE0kRgErAKeAaYlEY4DabQ4b0kIgJ4HPhIqj8LWJxZ1qw0/RHgh6m8mZn1oEHtFZD0IHAlMFLSduA24EpJk4EAtgB/AhAR6yUtAl4EGoGbIqIpLedmYBkwEJgXEevTKj4LLJT0BeA54P4Uvx/4pqR6Cp3pM7q8tWZmdtraTRQRcX2R8P1FYi3l7wDuKBJfCiwtEt/EW6eusvFjwO+01z4zM+tevjLbzMxyOVFUGV9pYmbl5kRhZma5nCjMzCyXE4WZmeVyoqgyvtLEzMrNiaIf+enGfUyY8wg7Dx6rdFPMrA9xouhHvrlyCwCrtx6obEPMrE9xoqgyHh5rZuXmRGFmZrmcKMzMLJcThZmZ5XKiMDOzXE4UZmaWy4miSvkZT2ZWLk4UVcajY82s3Jwo+qHARxtm1nHtJgpJ8yTtlvRCJva3kl6WtFbS9ySdneITJL0paU16fSNT53JJ6yTVS7pHKlwaJukcScslbUjvI1JcqVx9Ws9l5d/8/kU+3jCzTujIEcUDwNQ2seXAJRHxS8CrwC2ZeRsjYnJ6fSITvxe4EZiUXi3LnAOsiIhJwIr0GeDaTNnZqb6ZmfWwdhNFRDwJ7G8TeywiGtPHlcDYvGVIGg0Mj4iVUehlXQBcl2ZPB+an6flt4guiYCVwdlqOmZn1oHL0UfwR8Gjm80RJz0l6QtKvptgYYHumzPYUAxgVETvS9E5gVKbOthJ1zMyshwzqSmVJtwKNwLdSaAcwPiL2Sboc+DdJF3d0eRERkk67p1XSbAqnpxg/fvzpVq9K7q42s3Lp9BGFpD8E/jvw++l0EhFxPCL2penVwEbgXUADrU9PjU0xgF0tp5TS++4UbwDGlajTSkTMjYjaiKitqanp7CZVBfn2sWZWZp1KFJKmAn8B/GZEHM3EayQNTNPnU+iI3pROLR2SNCWNdpoJLE7VlgCz0vSsNvGZafTTFOBg5hSVleAL7cys3No99STpQeBKYKSk7cBtFEY5DQGWp/9gV6YRTu8HPifpJNAMfCIiWjrCP0lhBNWZFPo0Wvo17gQWSboB2Ap8NMWXAtOAeuAo8PGubGh/4+MKMyuXdhNFRFxfJHx/ibIPAw+XmFcHXFIkvg+4qkg8gJvaa5+ZmXUvX5ltZma5nCiq1N43TlC/+0jRee7GMLPT4URRpR5Zu4Nf/8oTrYPuuDCzTnCiqDIeHmtm5eZEYWZmuZwozMwslxOFmZnlcqIwM7NcThRmZpbLicLMzHI5UVQZj441s3JzouiHfGG2mZ0OJ4p+xAcbZtYZThRmZpbLiaLK7T58rNJNMLM+zomiyjQ1t+6B+L37nq5QS8ysWjhRVJm2txB/bd/R4gXNzD
qoQ4lC0jxJuyW9kImdI2m5pA3pfUSKS9I9kuolrZV0WabOrFR+g6RZmfjlktalOvek52qXXIeV5uGxZlZuHT2ieACY2iY2B1gREZOAFekzwLXApPSaDdwLhT/6FJ63fQXwPuC2zB/+e4EbM/WmtrMOMzPrIR1KFBHxJLC/TXg6MD9Nzweuy8QXRMFK4GxJo4FrgOURsT8iDgDLgalp3vCIWJmek72gzbKKrcPMzHpIV/ooRkXEjjS9ExiVpscA2zLltqdYXnx7kXjeOlqRNFtSnaS6PXv2dHJzzMysmLJ0ZqcjgW694DdvHRExNyJqI6K2pqamO5tRFcIPzTaz09CVRLErnTYive9O8QZgXKbc2BTLi48tEs9bh3VQZHKrH5NqZp3RlUSxBGgZuTQLWJyJz0yjn6YAB9Ppo2XA1ZJGpE7sq4Flad4hSVPSaKeZbZZVbB3WCT6SMLPOGNSRQpIeBK4ERkraTmH00p3AIkk3AFuBj6biS4FpQD1wFPg4QETsl/R54JlU7nMR0dJB/kkKI6vOBB5NL3LWYSV05JjBRxZmdjo6lCgi4voSs64qUjaAm0osZx4wr0i8DrikSHxfsXWYmVnP8ZXZVcZHC2ZWbk4UZmaWy4nCzMxyOVFUmf1vnKh0E8ysyjhR9HEnm5o5crzx1Od/WLGh1fxiI2Ibm5q7u1lmVkWcKPq4T3xzNZfctqxDZY+eaALgrxav784mmVmVcaLo41a8nH+xemNzsGHXYQCOnigceWSPQMzM2uNE0Q/83j/7KXdm1nlOFP1AsVt3HHCnt5l1kBNFP3XjgrpKN8HM+ggnij7sjQ72NbQcUGQPLOq2HuiGFplZNXKi6MPePNlU6SaYWT/gRNHHzJq3ituXeHirmfUcJ4o+5olX9/DAT7Z0qm7bLu03jjdy8M2TXW6TmVU3J4p+IN42UTDlb1bw3r9+rKebY2Z9jBNFH9bVB9YdPvbWBXgb9xwpQ4vMrBp1OlFIerekNZnXIUmflnS7pIZMfFqmzi2S6iW9IumaTHxqitVLmpOJT5T0dIp/W9Lgzm9q//bUhr2s2rK/6Lzfv28lV931RA+3yMz6ik4nioh4JSImR8Rk4HIKjz39Xpp9d8u8iFgKIOkiYAZwMTAV+LqkgZIGAl8DrgUuAq5PZQG+lJZ1IXAAuKGz7e3PmpqDP7i/9NXZz28/2IOtMbO+plynnq4CNkbE1pwy04GFEXE8IjZTeKb2+9KrPiI2RcQJYCEwXYVHtX0AeCjVnw9cV6b29nlzn9zIf7nj31vF/vTB54qWdYe1mXVFuRLFDODBzOebJa2VNE/SiBQbA2zLlNmeYqXi5wKvR0Rjm7gBX/7BK2+LLXn+ZxVoiZlVuy4nitRv8JvAd1LoXuACYDKwA7irq+voQBtmS6qTVLdnz57uXp2ZWb9SjiOKa4FnI2IXQETsioimiGgG7qNwagmgARiXqTc2xUrF9wFnSxrUJv42ETE3ImojorampqYMm2RmZi3KkSiuJ3PaSdLozLwPAy+k6SXADElDJE0EJgGrgGeASWmE02AKp7GWROGWp48DH0n1ZwGLy9BeMzM7DYPaL1KapKHAB4E/yYS/LGkyhcu7trTMi4j1khYBLwKNwE0R0ZSWczOwDBgIzIuIlntUfBZYKOkLwHPA/V1pr5mZnb4uJYqIeINCp3M29rGc8ncAdxSJLwWWFolv4q1TV5bR2NzFq+3MzDrIV2abmVkuJwozM8vlRGFmZrmcKMzMLJcThZmZ5XKiMDOzXE4UZmaWy4miStz35KZKN8HMqpQTRZW4Y+lLlW6CmVUpJwozM8vlRGFmZrmcKMzMLJcThZmZ5XKiMDOzXE4UZmaWy4nCzMxyOVH0Ie//8uOVboKZ9UNdThSStkhaJ2mNpLoUO0fSckkb0vuIFJekeyTVS1or6bLMcmal8hskzcrEL0/Lr0911dU291Wv7T9a6SaYWT9UriOKX4uIyRFRmz7PAVZExCRgRfoMcC
0wKb1mA/dCIbEAtwFXUHj06W0tySWVuTFTb2qZ2mxmZh3QXaeepgPz0/R84LpMfEEUrATOljQauAZYHhH7I+IAsByYmuYNj4iVERHAgsyyzMysB5QjUQTwmKTVkman2KiI2JGmdwKj0vQYYFum7vYUy4tvLxI3M7MeMqgMy/iViGiQdB6wXNLL2ZkREZKiDOspKSWo2QDjx4/vzlWZmfU7XT6iiIiG9L4b+B6FPoZd6bQR6X13Kt4AjMtUH5tiefGxReJt2zA3Imojorampqarm2RmZhldShSShko6q2UauBp4AVgCtIxcmgUsTtNLgJlp9NMU4GA6RbUMuFrSiNSJfTWwLM07JGlKGu00M7MsMzPrAV099TQK+F4asToI+NeI+IGkZ4BFkm4AtgIfTeWXAtOAeuAo8HGAiNgv6fPAM6nc5yJif5r+JPAAcCbwaHr1O4W+fDOzntelRBERm4D3FonvA64qEg/gphLLmgfMKxKvAy7pSjurwau7jlS6CWbWT/nK7D6iqdlHFGZWGU4UfUTgRGFmleFE0UfUbTlQ6SaYWT/lRNFH3LZkfaWbYGb9lBOFmZnlcqIwM7NcThRmZpbLicLMzHI5UZiZWS4nCjMzy+VEYWZmuZwo+oCeun3H0RONPbIeM+tbnCj6gHlPbe6R9Xxr5Ws9sh4z61ucKPqAHQeP9ch6fD8pMyvGicJO2XXoOMdONlW6GWbWyzhR2Cn3P7WZG+Y/035BM+tXnCislR/X76t0E8ysl+l0opA0TtLjkl6UtF7Sp1L8dkkNktak17RMnVsk1Ut6RdI1mfjUFKuXNCcTnyjp6RT/tqTBnW1vX+a+AzOrpK4cUTQCfxYRFwFTgJskXZTm3R0Rk9NrKUCaNwO4GJgKfF3SQEkDga8B1wIXAddnlvOltKwLgQPADV1or3XQqs372y9kZv1GpxNFROyIiGfT9GHgJWBMTpXpwMKIOB4Rm4F64H3pVR8RmyLiBLAQmC5JwAeAh1L9+cB1nW1vXybUo+v76D/9tEfXZ2a9W1n6KCRNAC4Fnk6hmyWtlTRP0ogUGwNsy1TbnmKl4ucCr0dEY5t4sfXPllQnqW7Pnj1l2KLepdKnnhaueo3Faxoq2gYzq5wuJwpJw4CHgU9HxCHgXuACYDKwA7irq+toT0TMjYjaiKitqanp7tX1mJNNzUyY8wj/8uMtFW3HnO+u41ML11S0DWZWOYO6UlnSGRSSxLci4rsAEbErM/8+4PvpYwMwLlN9bIpRIr4POFvSoHRUkS3fL/iaBjPrDboy6knA/cBLEfGVTHx0ptiHgRfS9BJghqQhkiYCk4BVwDPApDTCaTCFDu8lERHA48BHUv1ZwOLOtrcv8lgnM+sNunJE8cvAx4B1klrOS/wlhVFLkyn8ndsC/AlARKyXtAh4kcKIqZsioglA0s3AMmAgMC8i1qflfRZYKOkLwHMUEpOZmfWgTieKiHgKig7HWZpT5w7gjiLxpcXqRcQmCqOi+qXdh3rmHk9mVhmrNu/nrsde4f/98RWcMbD3Xv/ce1tmfPjrP6l0E8ysG31m0Rqe3ryfnT1048/OcqLoxQ4f8/MhzKzynCjMzCyXE4WZmeVyouglNu99gyPHe8+pppNNzZVuglnViz4yBt6Jopf4tb/7Eb9338pKN+OUP55fV+kmWBsnGps58MaJSjfD+iEnil5k7faDp6bfPFHZq7KfeLX67pnV1930r89y6eeXV7oZ1g85UfRSa7a9XukmMOnWpU4YvcjyF3e1X8j6FPXsjaE7zYmil/rMosrfhO9kU3DPig2VboaZVZgTRS80c94qdvSSC3C27jta6SaYVS13ZlunPdmLTvfsPXK80k0wswpzojAzs1xOFNZhNy7wkFnrfQ6+eZIPffUpNu45UummVC0nil7m0LGTlW5CSd0x6qa5uY+cpO1Foq+c2O4hP3x5F+saDnrgRTdyouhlfun2xyrdhB6z/MVdnP+XS3ll5+GyL3v9zw
7yw5cLia2xqZk5D69l6743yr6envKz1988Ne080VrL/ugjI037JCcKOy3l/G922fqdADy/vfzXjPzGPU/xRw/UsXLTPi689VEWPrONT3+7vEOOl67bwYQ5j3R7h//Og8f4b3f+8NTnvp4nTjY1c7yx/BeUHjvp2850l16fKCRNlfSKpHpJcyrdnnLasOswP924j397ru88Cvx/P7S2bMt6aPV2AAZ041VHM+a+dVuU5157vaz30/rz7zwPwO/+00/LtsxidrZ5gNWqzftZvXV/t64zT0S0ey+wtdtfZ8ve4kdw19z9JO/+Pz/g4JvlOc363WcLvz8/WL+zV90vrZr06kQhaSDwNeBa4CIKj1m9qLKtKp8P3v0k19+3suz/6Xanh1Zv5/xbHmHCnEf4/tqfcffyV2lqDm5cUMdXV2zgzkdfprEDNxTMPr3v899/kRONzUyY8wgf+upTp+IRwRe+/yL1u0/v1NRzrx0oOe+LS186rWWV0tjUzNF0m5WNe8p/SuuBH29m2/7CNSx7Drc+Yrn+vpX89r3dm5zy3PHIS0y69VGOnSx9VPCb//hjrvy7HxWdtyklkPf+dXlOsz5Vv/fU9JE+/gyX5uYoedR+5Hjjqe9ET1Nv7hiT9F+B2yPimvT5FoCI+JtSdWpra6OurveMzokIJPHyzkP83397gQEST2+u3H+D3WHgANGU6ZSeevE7mfG+cVw6fgS3LX6BL/7We/i5wW89dff1oyeY/Ln271k0avgQdh1664/kY//r/bxr1FlA4Q/1oDaPjmxsaubCWx9td7krb7mKUcOHcPREEwMHiIED1OoxlC0/s+bm4GRzM03N0ar9EcHEW1o/ufeu33kvv3352JLrfP3oCX7nGz/la79/2altyDp2sonjJ5v5l59s5u///a1O2TV/9cGS+2rTF6chgcp4RPbcawfYcfAY094z+m3zIoKm5njbPv6nj13Oj17ZwyVjhnPNxe/k4JsnuequJwB46XNTOWOgGCCx98hx/uw7z/MfG/a2qv8HU8bTcOBN9h45wa2/8Z/5hXN/jpHDhjBAhZ9NMVv3vcFnFj3P6q2t/yk476whNDYHF//8cP7+dyfT2BzUDBtyWvupsamZgB55NGntF5az98gJFt/0y7x33NlcddeP2LrvKPVfnMa+I8f5T2cMZOiQwnfvQ199inUNB9n4xWmn9suJxmYGDRADSuyn0yFpdUTUFp3XyxPFR4CpEfHH6fPHgCsi4uZSdTqbKBY9s425/7GpQ2U7ss8C2NQN/2laceedNYTdh7u3r2DIoAEcb8w/Who74kzOPGMgB46eLNl3cX7NUOCtztdjJ5tpyHRWn64LzxtWcl797iNIcEFN6TIRQQBNzdHqSvy2y63fXdnhp2eeMZA3c45i2nN+zdCipznb/j63PUIstn/rdx9h5LAhDD9zUKvf8wtqhp5KSNnltlpD5sOmEqfnLqgZ2qodF543rNX+P3/kUAYM0KlYSxs/ddUkPvTeny+6zPbkJYpBxYJ9jaTZwGyA8ePHd2oZI4YO5t1F/tMrvdL2i3RXonjPmHewruGtO822/Y++vxg0QDQ2ByOHDebS8WezbH333jTv3KGD+Vk7t1YZIHFBzTCaIloNJx45bDB7jxRuEX7R6OGFvxUBCPYePt6lRPGuUcNQiS/koAHi5Z2HS363gyjUVaE9LYmi5qwhb6tT6UTxnrHvYNAA8ZON+zpVf9J5wxg4QMX3VSbU8PqbpzrGJ503rOgRYP3uI+w9cpwrJp7D8ZToJ44cyi++c3jJ5WbX2pJMSiWKX3zn8FOJ4hffeRbn1wyluTlOlT/45kmmnH8u9buPMOHcnzv1s3rHmWfk7YJO6+1HFH3+1JOZWV+Qd0TRqzuzgWeASZImShoMzACWVLhNZmb9Sq8+9RQRjZJuBpYBA4F5EbG+ws0yM+tXenWiAIiIpcDSdguamVm36O2nnszMrMKcKMzMLJcThZmZ5XKiMDOzXE4UZm
aWq1dfcNcZkvYAWztZfSSwt91S/Yv3SWveH615f7TWl/fHL0RETbEZVZcoukJSXakrE/sr75PWvD9a8/5orVr3h089mZlZLicKMzPL5UTR2txKN6AX8j5pzfujNe+P1qpyf7iPwszMcvmIwszMcjlRJJKmSnpFUr2kOZVuT3eStEXSOklrJNWl2DmSlkvakN5HpLgk3ZP2y1pJl2WWMyuV3yBpVqW253RJmidpt6QXMrGybb+ky9P+rU91y/es0m5QYn/cLqkhfUfWSJqWmXdL2rZXJF2TiRf9HUqPCXg6xb+dHhnQa0kaJ+lxSS9KWi/pUyneb78jhccg9vMXhVuYbwTOBwYDzwMXVbpd3bi9W4CRbWJfBuak6TnAl9L0NOBRCg/omgI8neLnAJvS+4g0PaLS29bB7X8/cBnwQndsP7AqlVWqe22lt7kT++N24M+LlL0o/X4MASam35uBeb9DwCJgRpr+BvA/Kr3N7eyP0cBlafos4NW03f32O+IjioL3AfURsSkiTgALgekVblNPmw7MT9Pzgesy8QVRsBI4W9Jo4BpgeUTsj4gDwHJgak83ujMi4klgf5twWbY/zRseESuj8BdhQWZZvVKJ/VHKdGBhRByPiM1APYXfn6K/Q+k/5Q8AD6X62X3bK0XEjoh4Nk0fBl4CxtCPvyNOFAVjgG2Zz9tTrFoF8Jik1el54wCjImJHmt4JjErTpfZNte2zcm3/mDTdNt4X3ZxOpcxrOc3C6e+Pc4HXI6KxTbxPkDQBuBR4mn78HXGi6J9+JSIuA64FbpL0/uzM9F9Ovx0O19+3P7kXuACYDOwA7qpsc3qepGHAw8CnI+JQdl5/+444URQ0AOMyn8emWFWKiIb0vhv4HoXTBrvSITHpfXcqXmrfVNs+K9f2N6TptvE+JSJ2RURTRDQD91H4jsDp7499FE7FDGoT79UknUEhSXwrIr6bwv32O+JEUfAMMCmNzhgMzACWVLhN3ULSUElntUwDVwMvUNjellEZs4DFaXoJMDON7JgCHEyH38uAqyWNSKclrk6xvqos25/mHZI0JZ2fn5lZVp/R8gcx+TCF7wgU9scMSUMkTQQmUeiYLfo7lP7zfhz4SKqf3be9Uvq53Q+8FBFfyczqv9+RSumnOFYAAADMSURBVPem95YXhZELr1IYuXFrpdvTjdt5PoURKc8D61u2lcK55BXABuDfgXNSXMDX0n5ZB9RmlvVHFDoz64GPV3rbTmMfPEjhdMpJCueHbyjn9gO1FP6wbgT+kXRha299ldgf30zbu5bCH8LRmfK3pm17hcxonVK/Q+k7tyrtp+8AQyq9ze3sj1+hcFppLbAmvab15++Ir8w2M7NcPvVkZma5nCjMzCyXE4WZmeVyojAzs1xOFGZmlsuJwszMcjlRmJlZLicKMzPL9f8BhT8dFxXbqf0AAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7Hpx1nVWO1HA",
        "colab_type": "code",
        "outputId": "d7eaf6f2-f8c0-4c97-c828-ed60c0afd76e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 283
        }
      },
      "source": [
        "# NOTE(review): assumes ddpg_update appends one policy-network loss per call to p_loss; confirm.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(p_loss)\n",
        "ax.set(title=\"Policy network loss (p_loss)\", xlabel=\"logged step\", ylabel=\"loss\");"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[<matplotlib.lines.Line2D at 0x7f0bcb614ba8>]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 124
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAD4CAYAAAAD6PrjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dd5xU9b3/8ddnG0tdmkvHXXoTFBYpIgoiVcXYrjHWqPyiomJiDEa98VpiSbn5GTUJPyWJJl4TYzREFMWo8RoDCokFjGWjqGBDUUTp8P39MWeH2d2Z2WlnzpT38/FAZ77nzJnvnJ05n/Pt5pxDREQEoCToDIiISO5QUBARkTAFBRERCVNQEBGRMAUFEREJKws6A+nq2rWrq6mpCTobIiJ5ZfXq1R875/Zrmp73QaGmpoZVq1YFnQ0RkbxiZm9HS1f1kYiIhCkoiIhImIKCiIiEKSiIiEiYgoKIiIQpKIiISJiCgoiIhOVcUDCzmWb2mpnVm9nCoPMjIpJpW7bvYuvO3eHnq9ZtYu/e0DIG6z7+kr/VfxxU1nJr8JqZlQK3AUcC64HnzWyJc+6VYHMmTW3euovWFaVUlIXuKz7+YgcvvPMZRwyt5o2PvgBgULf2QWZRkrB91x4e+OcGZo/oQVWb8kDz4pzj1Q+2MLRHh6Rf+/n2XVx230uM79eZnXv2svTlD7j5+JFs+nInndtWUNu1Le99to1PvtxB96rW9OrYutFrN27ZQd/Obdi4ZQc9I7Zl2gFXP0b7yjL+8q3DmPKDp/hy5x6+PWMwn2/bxS+efrPRvj8/dQwzR3Tnn+98SlXrcm59sp4ubSu4Ys4wX/JmubTIjplNAK52zs3wnl8O4Jy7IdZr6urqXCojmo+59RnqP/oCC73Pvjx4/yktMcpKSigvNcpKjXc3bWt2jOWXTKa2a1scUF6ac4UuX9UsXMrkQfvx67PGcs9z73DFA2sAuGjqAG55oh6Au88+mLE1naksL034uB99vp1bnniDC6cOZMv23dyz8h1Om7A/969eT23XtlS1LmfKkGpKS6zlg3nWf7qVKx5YgwOefn1jwq/7wzcmUFfTuVHarj17+fHy19n0xU5+t+pd5hzQg6Uvv9/stc98ZwqTbnqyWfqA6nb8z7nj2a99q4TycN+qd9myfTfXPPQKV84Zyo8ee51tu/bE3L+yvITtu/YmdOzRfTvy01NGc8iNT0Td/quzxnL44OqEjpWs3Xv2hn83u/fsZeKNTzBvcj8m9O/Cyjc3cc1DrzB9WDcWnV4X9zhbtu/igKsfA+CokT146KXmf4tMObBPR44Z1ZNB3dpz6p0rk379A+dP5KC+nYDQ7wegY5tyPtu6K+U8vX7drPCNWbLMbLVzrtkJzrWgcAIw0zl3jvf8NGCcc25+k/3mAfMA+vbtO+btt6OO1o7r53/9Nx9v2UHDp284DQ6Hc7DXOXbtcezes5fdex0P/HNDQse9dPog5k8dmHR+8oVzjj+98B4LfvdCQvu3KivhtetmJXz8/t99mD17W/5ODuvRgT9fOKnF4OCco/byhxN+/6YeunASI3pVhZ//4q//5oZHXk35eA2WzD+EXzz9Jl8/pJYx+4cuFG9u/IJW5aWN7l4bLh5BuueccUwc0DXuPs+v28RnW3fxjd+sDv/9mp67SEOueiTh4AWw7sY5UdPf3PgFU3/014SPE7RXr53JkKuWZex4L189nfaVqZXsCiooREq1pJCqZ+s/5pQ7Wr5LOHNiDXNG9uC1D7ZwzIE96ZDiHy5XvPPJVib/oPmdbyJWXTmNru2a3xm/vH4zL7z7KSve2sTSFO7wqlqX8+L3psfc/sjL73Peb/+R9HGbOnNiDVcfMxzw5yLdq2NrNny2ryQaeQHMhaAAoVLTgy9s4IQxfRjZq4r6jV+w+Jm3GNqjAx3blHPxvdFvEmJdzJP9XF8b15frv3JA+P
myNe8zrrYLB127PKnjFJr662dRlmItRaygkFNtCsAGoE/E895eWs6YOKArb90wu8W7z189u45fPbsOgCsfXMPqK6fRJcqFMZd9tnUn7SvLeX7dJk5etCLl49Rd93j44nDqHSs5qG9HvjV9MEff+kxa+du8bRdvf/Il+3dp22zblu27MhIQIPS3PP/w/lR3qMzI8ZqKDAgAH2zeTvcqf94rVSf8/O8A/GbFO0m97g+r1zO8Z4dG7QOLn3kr6ff/7cp3wkHhNyve5soH1yR9jGwb3rMDa9/73Nf3SDUgxJNrFeHPAwPNrNbMKoCTgSUB56kZM+OtG2YzrrZzyzt7xlz3ON+4e7WPucqsLdt3ceA1y7nqT2vSCggNTr1jJZu37uKZ+o/56RP1rH770wzkEg77wVPhXhuRkq3iOX3C/nG3H/z9v/CTx19P6pipevWD0IXkyx27W9gzcX9bOJV7543P2PESdel9LzLr//4vzjkm3fQENQuXcs1DqfUb+cc7n+Kc43fPv5vhXGbWA+dPZN2Nc1h60aExS0q5LKeCgnNuNzAfeBT4F/B759zaYHMVnZnxu/8zgXU3zuG08fEvKA2Wrf2ACzJ09+q3BV51wD0rk7szjOWZ+o8Zdc1j4efH/+zZjBwXoN93H2b9p1sbpX24eXtSx7hm7gjunTeeO+I0bP7k8Tdibnvuu0cw54AeSb1nLA2NpcO/92jMfb49Y3DcY8yfMqDR814dWzO+X5f0M5eisdf/hfWfNu+skYzjbn+W2ssf5uUNm1vc99gDe3L+4f2pLC/h5uNHcu+88ZwwpneLr3vtupl847D+Kefxt+eMCzcm56ucCgoAzrmHnXODnHP9nXPXB52fRFw2M/4PNNLSl9/PmXrieP7y6kdBZyEpd694G+ccd694my937G4x/3+eP4m3bpjNgmkD+eP5EwEY368L04Z1Y1Sfjkm999KLJlHdoZLbvjaam08YmfJnaPCH1etb3Gdoj/jdfS+NETTuOXdc1PS/fvvwFt8zHR9/scPX4zd13OjeXDZzCK9eO4uTxvZhfL8u3HDcAcwZ2YNbTzko5utalZVy0REDOHRg/IZ1gG8eOYgDm3xXDonSIP/jk0Yl/wEiPLtwaqPnD104CYD2lf7U/udcUMhH7SvLWXfjnKSKims2bGbjlh0J9bTx05oNm1m2xr9ufNnyi7++yW9WvsNVD67hK7f/LeZ+q6+cxrMLp3JA7yrMjAXTBjG6yZ3dg16QaEnD33x4z309bE6q6xNz30MGZO5O3cy46IiBzUoEQLhH1nNXHMH8KQOYMbxbeNvE/tEvdvt3acub35/N35pcgPJVuygXzPLSEm47ZTRHjezJ6iunxXxtm4oy7j47evCMNKhbO8pLW+4a3aYitYv3KeP68uq1M+nZsTWvXDODKYP3Y87IHgzt0YGu7Sq4Zu7wlI7bklxraC4aR/001Mga2bMlyHzkY91nU1d5jY+vf/hF1O2PXTI5ocb+yHEr6Xrowkl0blsBwLmH9uNv9Z+kfKwF0way5IX3ePPjLyk145tHDgLg1ifrG+134dRQoKhuXxmzxBBNSYlRUQDjbc6cWMOo3vFLe13ateLoUT3584vvhdN+8h8HNtpn3Y1zWP7Kh5x7V/TejaFxTPvOV6zBdpFBOVF3ff1gJg/at1Jmm4oyfnnWweHnq648MuljJir/vwE55ulvT0lq/2VrPvApJ8nbtnMPNQuX5kz1VsPFrcH8KQOYNjT5H1iDIEZYj+hVFR4Z23+/dmkda8G0Qdw7bzznTKqNWk3RoCEIpSKD8TAQ354xmKuPGZ7Q4MbDBzVenvjYg3o12+fIYbG/b4cO6hoOBAf26ciS+YdE3S+Zm4zLZg5m3Y1zGgWEbFNQyLC+Xdoktf8Hn2+n/qMtPuUmcX9YvZ5f/31d0Nlo5FvTBzO+374eXpfOGEyq42quTbKofd2xI+JubxhwloyOGZg+orpDJVceNSzuRa+lu/0rZg+Nua
1NReKjz3NRqyRG9x43el8QiDcq+KdfPYhro3wfWpWVsnDWEP7n3PE8eMEhGZnV4OSxfdM+RroUFHyw7sY5nDmxJuH9p/34af8yk6BL73uRGzMwUhdIqqtuS245+SAumTaIt26YDcDuFNtgTptQk9T+pxwc/8f5y7PGxtw2tiZ6wGhfWR73ghwpmeA3pHt7BlS3447T6xhQ3Y5jDuwZd/9zJ/fjhydGb/xsU1HW6GKZb5IJ1pF38DVxbuaOHtUzZg/D8tISJvRvua3o6FHx/ybhPCW0l78UFHxS27X5gKp4YtVb+ml7nDl0UrVswaHcc+549k+yxNRUQ8+K6g6VXDxtYPgH3LRROBHJ/i0gVL8eT7wR6mP2jx0UT5+YWPflO5sM8IrVawhg2YLJPP7Nw5g2rBuPf/OwhBo24326G447IM7W3DRtaDX1189KuTtoWYm/l8KmVaHnHlobfhzZDToXqu8UFHzytXF9uXbu8IQnPlv+yoc+56ixNRs2Z3QOllkjurPotDEM6d6B0hLj0QWTM3bsSPOnDuD+8xLrHdQg2f3TNW9yv5jbEr34XLf0X42e18UJNKmILIeM6NW4gdTvC6QfzCyt0b21+yV/45CMpu1ZkdeFacO60d0bLd86B6rv1PvIJ2WlJZw2oYYT6/pw3dJXkp4ewE+3P1XPr70pODLlZ6eOafS8sryU8f06s+LNTSkdL9YNU2mJJV2fn07DazTLFhya8muTmNy1kVRnwkxEtHmpsq2irISdu0MT5C2/ZDJH/ndyVaqpntcG/xGjK3E0L/7ndCor0vt7HDe6N+99tp1LvB5kf7881BU4kz3fUpV/twR5prK8NOF+yp9vD02hu2rdJjZvS3063ZbcvOw1Pvw8M4OJendqzVdj1L+PbKFbYDwnt1Cnnw2x+rL3amGe/XiDinLhRw+Ng26Qc2LefbbXzdIRHnE8MIVeYul+XxIp0f/itDE8eenhVLUpp1VZenf0rcpKuPqY4VS1DlVDmlnOfDcUFLIgkYYogJFXP8ZvVrzNCT//O2f/6nmfc5UZz3xnasw66FS/4g9dOInLZw1JPVMZEmtMw94WZnzOt7U19gYYFcbVdgnn4YcnjgqPl1lx+RFJHad3igvivHH9LB44f2JCC/rMGN49pfapaHIlAESTX9/ePDVlcDX/vCqxwSYNsz/6NbtirAVV/DCsZ/IrZ0Gob3+mfjSJTFeQrD0JXERbJ7GwUBAiT2/fzul1CoiUbAeDhq61w5usu9C9qjLhEeALZw1hQHVqY0DKS0vyfq6iTFNQyJJOSdZrOzJ/91azcGmzaZrTccm0QXG3zz0w8a6Nfs3gme5N8OIzm0+Ql8id9aorp8Vd6yFoJRFRoUsG21yWXTw5qYBYWmLcf94E7ooYrdtg8Zmxu/1G+sZh/XP6zjuaXFrHpik1NBeB7bv2MPWHT2X0mJmeFmN8vy4svWgSm9NYmjCadKtGonU9TeSYbVvF/mkdN7oXf/xHsMuEzD6gR3j1vDZx8pqs1hWldGxTzrbNiXd3jtWFN916e0mNSgo5bNvOPWzbuYcdu/ewZ69jx+49vJfCnf66T77kvSSnkg7C8J5VLS77mCw/bshaalNoyY9POjDhwUx+qSgrCfeVb3qPneo9d8Ogt2zeBGeylJNNuVtOUFDIqiHdE+9V4RwccPWjDP3PZQy+chkX3ftPLvndC0y88YmEA4Nzjvn3/IP3MxgQqtu3ajTwpiXh3iUZ9u/vz05ov3RrFXp3al5HnkibQkty4WLW8DFSPUePXHwoNx+/b6rwIKbImD+1+Syx+aB9BktnmaagkEXJDqKKnNJh6Uvv8/DLocnzLrin8UI996x8h01f7mz2+p8+Uc9DL73PWb/MXE+mGcO7c8WcYQnvXxNlqcymvn5I4kGmQSITngExp3NIVLRlMaOt9Jas78wcwn8FODsu7LtbLUkxKgzt0YFBETc65pUxMtkedtPx8UdXN3TpzDe53AaioJBFbV
uVpTSRWlNbd+yrr33jwy1894GXufjefzbbb+17La9Q5bdEbqpPGtvyilip6pliV8V4MlE90rqilDOSmB/LD5n4HJGXtoY4fe3c+JMJJiPW+hTh98/da2veUlDIskz0Ooi8E9vhjQL9+IvmJQXzYXqtZJd0TOSucUj31LquBqWyPLs/mwXTBvpy3Ez3cGvowz99ePeMHdPMaJsDUz8Uk9yt2CpQidY8xNstMq403ClFCzaZvot6dMFkBifRLgLQJ0qdfL6r7tC8SslPZenO4RBDwyC7ZI7fobKMz7fvDj+P/I7F63GVjkSrCiUzVFLIskS7SDbMAxPNGx99Qc3CpWzftSdcGtjw2TbWNFnQPNNBIdmAAC3PNurXkoKFxK/ePBdOHcDXD6lNaoqIIEZrx5smxo/SsN96d8p8lWYmKShk2e49mfuFD7lqWfjCv2X77vDSmjt372XyzU/yRAuL1+eChvluilUinQ/8Wsa7fWU5/3n0MCqbDDaL93ZNbzQiL8rJNp5edVRiHRb6V8furDAxg+teZ8tvz2l5/ecgKShk2e50O7m34Pl1m1j51ie8s2kr23f5+16ZkOqi5oWipcn1wJ/R7fEkU1sTOUVGMiHhlHF9+dq4xEoo8Rqbq9tntyovE/ZPoEdekBQUsqxHVWaLjk1vzk78+d857c7nMvoeAKdPSGxxGElO5AU/V3rSmBlXH53YXXxVm3KO9VZ6Syb/3//KAc1KKJIbFBSyLNP1idmqU70mg90MJbqXYsyVFMQ0Oa1SuGBHBoU7z2g+Z5TkBwWFLLs4w90Lg5z2OFWJ3oUWm/Yxlvj0c4Gd5DW/CYn2DTxiaDf/syK+yKVvW1HIdB3odUtfyejxIk1McB2IljSdLfTMFEYwp+L40bnfiJ1ITO+XoTn8/ZaPPYGkOQWFPPe3+k98O/ZNEfPapKOqdTmvXTezUdo5k2qTWgIxFX0653bXv0TNHJG5wWDpitZukIeFVYmjuLt+SFylJcYD509kTQYW/Gk6DfKVCXZHTMf5h+fnZGlNBTFPTv/9mi9aM6pPRxbOHMLa9zY3GtneEBNypaE8V7WvLKNTm+AnQmyJgoLE1KlNBT07ts7blakyVRc/c3h3lq39ICPHiqVHlIn3gnRwbWeevPRwpkSsw3H32QfTobK82fKyyUzd8tglk9my3b/1x3NZrI4EuUbVRwE4Z1J26tTT0bailNaacwaAb02Pv8JcOjp7U2hfdIQ/8xulI9n1iBMp0Qzq1j7mojqxFEr1lJnl9OyoDRQUApCNqpN0pdIlsVD5+UOuLC9l3Y1z+GoSU03kmkSv2dcem9luzV3btcro8STEt6BgZj8ws1fN7CUze8DMOkZsu9zM6s3sNTObEZE+00urN7OFfuUtlxw1skfQWYgqmQWBismyBYcGnYXc07BYTwu7pbqwTKyYvPSiSSkdT+Lzs6SwHBjhnBsJvA5cDmBmw4CTgeHATOB2Mys1s1LgNmAWMAz4qrdvQfvRSektAuOXLnlwF3b6hP2ZOqQ6a+/Xr2vbvJvmOxsaRmX7VaCKVX3ULcuz1RYL34KCc+4x51zDHLsrgIZO43OBe51zO5xzbwH1wMHev3rn3JvOuZ3Avd6+BU19u1N3zdwRLD5zbNDZKBrtYsxTFV7Ws4Xvch5UpwvZa1P4OvCI97gX8G7EtvVeWqz0ZsxsnpmtMrNVGzdu9CG72ZOrP5RMLAZUKHL1b5RtLU2D3vQ8nTzW33Eo4o+0uqSa2eNAtJE1Vzjn/uTtcwWwG/htOu8VyTm3CFgEUFdXl9dXL11v0te6vJRtu/a0vGOKarq05aS63nw9D3qN+aVh0rtoYt0/9OmcmQWWsj1LbLFLKyg456bF225mZwJHAUe4fbeeG4DIW4jeXhpx0kViGlDdjpebLDA0IcllQ+MpLTFuPiE3236y5ScnHxRz24l1vVm29gNG9q5qlB5ZcigtMQ4btJ9f2ZMM8rP30UzgMuAY59zWiE1LgJPNrJ
WZ1QIDgeeA54GBZlZrZhWEGqOX+JU/KRzFvlBP0I4Y2o11N86hd5ylV5+69HA65sFoXvF3RPOtQCtgudfPe4Vz7hvOubVm9nvgFULVShc45/YAmNl84FGgFFjsnFvrY/5yQq6uP5tPBfbTJ+zPAb2rOO72Z8NpagcInpql8pNvQcE5F3PiGefc9cD1UdIfBh72K0+5KB9GOOY6M2N0307cfPxILrv/paCzk5aRvat4af3mlncU8YlGNAfkV2eNZcn8Q4LORmx5eJd3UkRvl3yNtRWljX+SNV0y01ibz1TiyC4FhYAcPriakb07trxjQPzo8dG3cxvO0LKeScnn62Fkw7Iu7PlDs6RK1jx92RTf32PK4P148rWNeTsoMF9LONGM6FVF706tWf/ptrSO07CWcyaOJS1TSUEaGVeb3AyWueasLK3qli0FFCNSNmN4d74zcwi/OuvgoLNSFBQUpJFenQpjtbJ8la8lnFgyUfIpLTHOO7w/7VKcUE+So6AgjTRclPK9Djhvq2HyNd8tyEQbVVlpgZ6cHKOgIAUlz2NZs5gwPoMjs4OQyZKP1k/IDgUFaaThDjvfSwqF4pxD+wWdhbQM7xmaajxTq/ipi67/VEknBSXfZ3dtWu2VowPeE/ajk0ZxzqG1VLfX2gf5QiUFKSgNIUEjxXNDm4qypNdklmApKOSAX5w2JugsFJx8DQmF1vtI8o+CQg7o0jZ3Zo9suCTl7Rz2eZrtBtOHdws6CzlNJUD/KSjkgFy6jhVKQ3O+XjvOnFjT6LkugpJtCgrSSL5XX+RtCcfTNAjk919D8pGCQg5oXZ6Z7nqZUNWmHICeHfN7ZLMupoVJf1f/qUtqDhjRq6rlnbJkfL/OHNinI1OHVAedlZTU1XSmZ1UlC6YNCjor4oNcXZSqkCgoSDOzD+gRdBZS1qGynGcvPyLobIhPStTG4jtVH0kjkwdqcXXJXWNqOgWdhYKnkoI0Ulaq+wTZ59pjR7B9556gsxF29dHDuWflO0Fno6ApKIjksKBrS04bn1sr5VWUlWCWe/kqJAoKIpJX3rphTtBZKGiqK5CwI4dpNK1IsVNQkLBZI7oHnQVpIt8HE0r+UVAocv972RTmHtgTCL7+WkSCp6BQ5Pp0bqN70RxzyIB9q60pUEu2KSjkiPvPmxjYe+f3bEGFZ2L/rkFnQYqYgkKOGLN/8INyVH8tIgoKQmdvPYc2GVpHV0Tyl8YpCN+ZOYQB1e3UJVVEFBQEKstL+do4jRAVEVUfFbW1/zUj6CyISI7xPSiY2bfMzJlZV++5mdktZlZvZi+Z2eiIfc8wsze8f2f4nbdi17aVCoq5Tl1SJdt8vSqYWR9gOhA5reEsYKD3bxzwM2CcmXUGvgfUEeoludrMljjnPvUzjyIiso/fJYX/Bi6jcVf4ucBdLmQF0NHMegAzgOXOuU1eIFgOzPQ5fyI5J5eWZ5Xi41tQMLO5wAbn3ItNNvUC3o14vt5Li5Ue7djzzGyVma3auHFjBnMtErxTNS20BCit6iMzexyINovaFcB3CVUdZZxzbhGwCKCurk4DcqWgVJTtu1czNSpIlqVVUnDOTXPOjWj6D3gTqAVeNLN1QG/gH2bWHdgA9Ik4TG8vLVZ60blyztCgsyAiRcqX6iPn3MvOuWrnXI1zroZQVdBo59wHwBLgdK8X0nhgs3PufeBRYLqZdTKzToRKGY/6kb9ct1/7VkFnQUSKVBB9Eh8GZgP1wFbgLADn3CYzuxZ43tvvGufcpgDyJyJStLISFLzSQsNjB1wQY7/FwOJs5EkkH6hFQbJNI5pFcpjamSXbFBRERCRMQSEHOXWyFZGAKCjkIKe10EQkIAoKIiISpqAgksO0RKpkm4KCiIiEKSiI5KBuHTSqXYKhoJCD1PtI9B2QoCgo5JCKUv05JKQhJmjwmmSb1mPMIc8snMJnW3exZsPmoLMiOUIxQbJNt6Y5pLp9JYO6tQ86GyJSxBQUREQkTEEhB6mRUUSCoq
AgIiJhCgoiOUilRQmKgoJITvKigrofSZYpKBSZthWlQWdBkqC5jyTbFBSKyHdnD+H4Mb2DzoaI5DAFhRzkZ3WyRk2LSDy6QhQRVUXkDzU0S1AUFERymOY+kmxTUCgyZao+EpE4dIUoImZQVqJbTxGJTUEhBzmfKpT/Y2wfX44rIoVDQaGItGulmdLzRZ/ObQAoL9FPVLJL37g8deHUAUFnQXy0+MyxLDptDFVtyoPOihQZBYUiYmY4X0dBSKZ0blvB9OHdg86GFCEFhTxxzKieQWdBRIqAgkKemDe5X6Pn6kMkIn5QUMhBU4dUN0vrXlUZQE5EpNj4GhTM7EIze9XM1prZzRHpl5tZvZm9ZmYzItJnemn1ZrbQz7zlsi7tWrW8k4a6iogPfOujaGZTgLnAKOfcDjOr9tKHAScDw4GewONmNsh72W3AkcB64HkzW+Kce8WvPOayqtblbN62i9ISY8/e5o3Do3pXpXTchiEQl04fFH9HESlKfnZcPw+40Tm3A8A595GXPhe410t/y8zqgYO9bfXOuTcBzOxeb9+iDAotOWJot7RebyppiEgUflYfDQIONbOVZvZXMxvrpfcC3o3Yb72XFiu9GTObZ2arzGzVxo0bfch68Pwa1SwiEk9aJQUzexyI1pn6Cu/YnYHxwFjg92bWL8q+SXPOLQIWAdTV1RXF1VP39SKSDWkFBefctFjbzOw84I8udMv7nJntBboCG4DISXh6e2nESZcMKYoIKiIp87P66EFgCoDXkFwBfAwsAU42s1ZmVgsMBJ4DngcGmlmtmVUQaoxe4mP+8kKvjq0BKIsyB86oPh2znR0RKXB+NjQvBhab2RpgJ3CGV2pYa2a/J9SAvBu4wDm3B8DM5gOPAqXAYufcWh/zlxd+edZYXn1/S9Q5cEpVpyQiGeZbUHDO7QROjbHteuD6KOkPAw/7lad8suj0Ou7437eo7dKW/vu1A+DEMb25b/X68D7qQSQimaYRzTlqfL8u3HFGHSURi+KM69el0T6phIQOlaESR/tKTaMtIs3pypBHhvfs0Oh5KgWFsyfV0rq8hFMO7puhXIlIIVFQyCNDezQJCimUFSrKSjjzkNpMZUlECoyqj0REJExBIZ+pnVlEMkxBoUD9+KRRQVwqYtcAAAj8SURBVGdBRPKQgkIei1dQOG5076zlQ0QKh4JCHtMwBRHJNAWFPJZK7yMRkXgUFEREJExBQUREwhQU8pjaFEQk0xQUREQkTEEhj6mkICKZpqCQxyJ7H93y1YPCA9YaFuYREUmWJsTLM88unMp7n20DGpcUjhnVk607d3P7U//mpuNHBpQ7Ecl3Cgp5pmfH1vSMURJoU1HG4988LMs5EpFCouojEREJU1DIY1qOU0QyTUFBRETC1KaQx0oSLCj8ef4kqju08jczIlIQFBTyWFmCUWFgt3ZUlpf6nBsRKQSqPspralMQkcxSUCgCzgWdAxHJFwoKIiISpqAgIiJhCgoiIhKmoCAiImEKCkVAA59FJFEKCiIiEqagICIiYb4FBTM70MxWmNkLZrbKzA720s3MbjGzejN7ycxGR7zmDDN7w/t3hl95ExGR6Pyc5uJm4L+cc4+Y2Wzv+eHALGCg928c8DNgnJl1Br4H1AEOWG1mS5xzn/qYRxERieBn9ZEDOniPq4D3vMdzgbtcyAqgo5n1AGYAy51zm7xAsByY6WP+RESkCT9LCguAR83sh4SCz0QvvRfwbsR+6720WOkiIpIlaQUFM3sc6B5l0xXAEcAlzrn7zewk4E5gWjrvF/G+84B5AH379s3EIUVEhDSDgnMu5kXezO4CLvae3gfc4T3eAPSJ2LW3l7aBUJtDZPpTMd53EbAIoK6urmine9P4AxHJND/bFN4DGlaRnwq84T1eApzu9UIaD2x2zr0PPApMN7NOZtYJmO6liYhIlvjZpnAu8H/NrAzYjlfdAzwMzAbqga3AWQDOuU1mdi3wvLffNc65TT7mT0REmvAtKD
jnngHGREl3wAUxXrMYWOxXnkREJD6NaBYRkTAFBRERCVNQKALqpSQiiVJQyGMN1/pzD60NNB8iUjgUFArAmP07B50FESkQCgoiIhKmoFAEXNGO+RaRZCkoiIhImIJCQVBRQEQyQ0EhjyXa1VRdUkUkUQoKIiISpqAgIiJhCgoiIhKmoCAiImEKCiIiEqagUAA0OE1EMkVBIY8ZifU1TXQ/EREFhQJWXqpgICLJUVAQEZEwBQUREQlTUBARkTAFBRERCVNQyGOd2pYDUFleGnW7uqqKSLLKgs6ApO7KOcMY0r0Dhw/eL+5+miVVRBKloJDH2rYq44yJNUFnQ0QKiKqPREQkTEFBRETCFBRERCRMQUFERMIUFApY64pQV1V1TRWRRKn3UQF74PyJPPHqR1SUKfaLSGIUFArYgOr2DKhuH3Q2RCSPpHULaWYnmtlaM9trZnVNtl1uZvVm9pqZzYhIn+ml1ZvZwoj0WjNb6aX/zswq0smbiIgkL916hTXAccDTkYlmNgw4GRgOzARuN7NSMysFbgNmAcOAr3r7AtwE/LdzbgDwKXB2mnkTEZEkpRUUnHP/cs69FmXTXOBe59wO59xbQD1wsPev3jn3pnNuJ3AvMNfMDJgK/MF7/a+BY9PJm4iIJM+vFshewLsRz9d7abHSuwCfOed2N0mPyszmmdkqM1u1cePGjGZcRKSYtdjQbGaPA92jbLrCOfenzGepZc65RcAigLq6OnW4FBHJkBaDgnNuWgrH3QD0iXje20sjRvonQEczK/NKC5H7i4hIlvhVfbQEONnMWplZLTAQeA54Hhjo9TSqINQYvcQ554AngRO8158BBFIKEREpZul2Sf2Kma0HJgBLzexRAOfcWuD3wCvAMuAC59werxQwH3gU+Bfwe29fgO8A3zSzekJtDHemkzcREUmeuTyfA8HMNgJvp/jyrsDHGcxOvtP5aEznozmdk8by+Xzs75xrtkJX3geFdJjZKudcXct7Fgedj8Z0PprTOWmsEM+HJsUREZEwBQUREQkr9qCwKOgM5Bidj8Z0PprTOWms4M5HUbcpiIhIY8VeUhARkQgKCiIiElaUQSHWmg6FyMzWmdnLZvaCma3y0jqb2XIze8P7fycv3czsFu+8vGRmoyOOc4a3/xtmdkZQnycVZrbYzD4yszURaRk7B2Y2xjvH9d5rLbufMDkxzsfVZrbB+568YGazI7YV9NooZtbHzJ40s1e89WEu9tKL8zvinCuqf0Ap8G+gH1ABvAgMCzpfPn7edUDXJmk3Awu9xwuBm7zHs4FHAAPGAyu99M7Am97/O3mPOwX92ZI4B5OB0cAaP84BoSlcxnuveQSYFfRnTuF8XA1cGmXfYd5vpBVQ6/12SuP9jgjNZnCy9/jnwHlBf+YWzkcPYLT3uD3wuve5i/I7UowlhahrOgScp2ybS2jNCmi8dsVc4C4XsoLQJIU9gBnAcufcJufcp8ByQosn5QXn3NPApibJGTkH3rYOzrkVLvTrv4scXwskxvmIpeDXRnHOve+c+4f3eAuhKXh6UaTfkWIMCrHWdChUDnjMzFab2TwvrZtz7n3v8QdAN+9xsutg5LNMnYNe3uOm6flovlcdsrihqgSf10bJNWZWAxwErKRIvyPFGBSKzSTn3GhCS6BeYGaTIzd6dy5F3S9Z5wCAnwH9gQOB94EfBZud7DOzdsD9wALn3OeR24rpO1KMQSHeWg8Fxzm3wfv/R8ADhIr9H3pFWrz/f+TtHuvcFOI5y9Q52OA9bpqeV5xzH7rQTMZ7gf9H6HsCyZ+P8NooTdJzmpmVEwoIv3XO/dFLLsrvSDEGhahrOgScJ1+YWVsza9/wGJgOrCH0eRt6RkSuXbEEON3rXTEe2OwVnx8FpptZJ69aYbqXls8ycg68bZ+b2XivPv108nAtkIaLn+crhL4nUARro3h/tzuBfznnfhyxqTi/I0G3dAfxj1DvgdcJ9Z64Iuj8+Pg5+xHqFfIisLbhsx
Kq9/0L8AbwONDZSzfgNu+8vAzURRzr64QaGeuBs4L+bEmeh/8hVCWyi1B97tmZPAdAHaGL6L+BW/FmCsjVfzHOx93e532J0EWvR8T+V3if7TUies3E+h1537vnvPN0H9Aq6M/cwvmYRKhq6CXgBe/f7GL9jmiaCxERCSvG6iMREYlBQUFERMIUFEREJExBQUREwhQUREQkTEFBRETCFBRERCTs/wNpvtJnQRIDKQAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ho1VsZVS7u7d",
        "colab_type": "text"
      },
      "source": [
        "# Testing"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-HIsYUA6_dUN",
        "colab_type": "code",
        "outputId": "c265c581-2fc6-4057-828f-596844a26d69",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "#prediction algorithm\n",
        "RECS_PER_USER = 5  # recommendations evaluated per test user\n",
        "it2 = iter(test_dataloader)\n",
        "precision = 0\n",
        "test_pred_dict = dict()\n",
        "# The final batch is skipped (presumably intentional, e.g. a short last batch) - TODO confirm.\n",
        "n_test_users = len(test_dataloader)-1\n",
        "# Evaluation only: no_grad skips building autograd graphs for every policy forward pass.\n",
        "with torch.no_grad():\n",
        "  for _ in range(n_test_users):\n",
        "    first = next(it2)\n",
        "    item_b,rating_b,size_b,userid_b,idx_b = first['item'],first['rating'],first['size'],first['userid'],first['idx']\n",
        "    memory[idx_b] = [item[0] for item in item_b]\n",
        "    state = drrave_state_rep(userid_b,item_b,memory,idx_b)\n",
        "    count = 0\n",
        "    test_pred = set()\n",
        "    for j in range(RECS_PER_USER):\n",
        "      state_rep = torch.reshape(state,[-1])\n",
        "      action_emb = policy_net(state_rep)\n",
        "      action = get_action(state,action_emb,userid_b,item_b,test_pred)\n",
        "      # get_action falls through (returns None) once every candidate is already in test_pred.\n",
        "      if action is None:\n",
        "        break\n",
        "      rate = int(users_dict[userid_b[0]][\"rating\"][action])\n",
        "      # Assuming 1-5 star ratings, this maps them onto [-1, 1]; a hit is any rating above 3.\n",
        "      rating = (rate-3)/2\n",
        "      reward = torch.Tensor((rating,))\n",
        "\n",
        "      if reward > 0:\n",
        "        count += 1\n",
        "        update_memory(memory,int(users_dict[userid_b[0]][\"item\"][action]),idx_b)\n",
        "      next_state = drrave_state_rep(userid_b,item_b,memory,idx_b)\n",
        "      state = next_state\n",
        "    # Per-user precision: fraction of the RECS_PER_USER recommendations rated above 3.\n",
        "    precision += count/RECS_PER_USER\n",
        "    test_pred_dict[userid_b[0]] = test_pred\n",
        "# max(..., 1) avoids a ZeroDivisionError when the test loader has a single batch.\n",
        "print(\"p\",precision/max(n_test_users, 1))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "p 0.752925353059846\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_stX77i9A1EV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#Getting Cosine similarity of recommended items for a particular userid that has been evaluated\n",
        "from sklearn.metrics.pairwise import cosine_similarity as cs\n",
        "import seaborn as sns\n",
        "\n",
        "def get_cosine_sim(userid):\n",
        "  \"\"\"Print the ratings of `userid`'s test recommendations and return their embeddings.\n",
        "\n",
        "  Despite the name, no similarity is computed here. The embeddings of the items in\n",
        "  test_pred_dict[userid] are stacked with np.array (in set-iteration order), and the\n",
        "  caller feeds the result to cosine_similarity.\n",
        "\n",
        "  Args:\n",
        "    userid: key into test_pred_dict and users_dict (filled by the testing cell).\n",
        "\n",
        "  Returns:\n",
        "    np.ndarray built from movie_embeddings_dict[item] for every recommended item.\n",
        "  \"\"\"\n",
        "  test_pred = test_pred_dict[userid]\n",
        "  for i,item in enumerate(users_dict[userid][\"item\"]):\n",
        "    if item in test_pred:\n",
        "      print(item,\":\",users_dict[userid][\"rating\"][i])\n",
        "\n",
        "  return np.array([np.array(movie_embeddings_dict[item]) for item in test_pred])\n",
        "\n",
        "# userid_b is whatever the last batch of the testing cell left behind.\n",
        "test_embed_array = get_cosine_sim(userid_b[0])\n",
        "# `linewidths` is the documented sns.heatmap parameter for the gap between cells.\n",
        "ax = sns.heatmap(cs(test_embed_array), linewidths=0.5)\n",
        "plt.show()\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2MHnXsrP79pN",
        "colab_type": "text"
      },
      "source": [
        "# Saving and Loading Models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YPj1Ejb7A2cn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "\n",
        "PATH = '/content/gdrive/My Drive/RLProject/Models/drravepolicy_net.pth'\n",
        "# torch.save does not create missing folders, so make sure the Drive models directory exists.\n",
        "os.makedirs(os.path.dirname(PATH), exist_ok=True)\n",
        "torch.save(policy_net.state_dict(), PATH)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8HJdozZcBV5a",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "\n",
        "# Checkpoint the critic and both target networks (the paths are reused by the reload cell).\n",
        "value_PATH = '/content/gdrive/My Drive/RLProject/Models/drravevalue_net.pth'\n",
        "tpolicy_PATH = '/content/gdrive/My Drive/RLProject/Models/drravetpolicy_net.pth'\n",
        "tvalue_PATH = '/content/gdrive/My Drive/RLProject/Models/drravetvalue_net.pth'\n",
        "\n",
        "# torch.save does not create missing folders.\n",
        "os.makedirs(os.path.dirname(value_PATH), exist_ok=True)\n",
        "for net, ckpt_path in ((value_net, value_PATH),\n",
        "                       (target_policy_net, tpolicy_PATH),\n",
        "                       (target_value_net, tvalue_PATH)):\n",
        "  torch.save(net.state_dict(), ckpt_path)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FAPPzXDCCA8e",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Persist both DataLoaders; np.save appends the .npy extension to each path.\n",
        "# NOTE(review): a DataLoader is not an ndarray, so np.save presumably stores it as a pickled\n",
        "# object array; reading it back then needs np.load(..., allow_pickle=True) - confirm.\n",
        "for save_path, loader in (\n",
        "    ('/content/gdrive/My Drive/RLProject/Models/train_dataloader', train_dataloader),\n",
        "    ('/content/gdrive/My Drive/RLProject/Models/test_dataloader', test_dataloader),\n",
        "):\n",
        "  np.save(save_path, loader)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WU5Ac2dGCTIj",
        "colab_type": "code",
        "outputId": "fb32e68e-d56b-486f-a0fc-86ed072e8230",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 129
        }
      },
      "source": [
        "# Rebuild the four DDPG networks and restore the weights saved above.\n",
        "# map_location='cpu' lets checkpoints written on a GPU runtime load on a CPU-only one;\n",
        "# load_state_dict copies the values into the freshly built modules either way.\n",
        "# NOTE(review): NET_ARGS must match the constructor arguments used at training time;\n",
        "# their meaning is defined where Actor/Critic are declared (not visible here).\n",
        "NET_ARGS = (5500, 100, 256)\n",
        "\n",
        "def _restore(net_cls, ckpt_path):\n",
        "  \"\"\"Build net_cls(*NET_ARGS), load the state dict stored at ckpt_path, return it in eval mode.\"\"\"\n",
        "  net = net_cls(*NET_ARGS)\n",
        "  net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))\n",
        "  return net.eval()  # eval() returns the module itself\n",
        "\n",
        "policy_net = _restore(Actor, PATH)\n",
        "value_net = _restore(Critic, value_PATH)\n",
        "target_policy_net = _restore(Actor, tpolicy_PATH)\n",
        "target_value_net = _restore(Critic, tvalue_PATH)\n",
        "\n",
        "# eval() disables the Critic's Dropout layer; call .train() on the nets before resuming training.\n",
        "target_value_net\n"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Critic(\n",
              "  (drop_layer): Dropout(p=0.5, inplace=False)\n",
              "  (linear1): Linear(in_features=370728, out_features=256, bias=True)\n",
              "  (linear2): Linear(in_features=256, out_features=256, bias=True)\n",
              "  (linear3): Linear(in_features=256, out_features=1, bias=True)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1LyKW_ns7hv5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Keep a handle on the real np.load so the restore cell below stays valid.\n",
        "np_load_old = np.load\n",
        "\n",
        "# Pass allow_pickle=True for this one call instead of monkey-patching np.load globally:\n",
        "# with the global lambda, re-running this cell made np_load_old point at the patched\n",
        "# lambda itself, which broke every later np.load call.\n",
        "# allow_pickle can execute code embedded in the file - only load files you created.\n",
        "train_data = np.load('/content/gdrive/My Drive/RLProject/Models/train_users.npy', allow_pickle=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bNsZXyGHAid7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Put back the np.load captured before the pickle-loading cell above.\n",
        "setattr(np, 'load', np_load_old)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}